import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Conv2D, MaxPooling2D, Dropout
import pathlib

# 1. Load the MNIST dataset as (train, test) splits of images and labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 2. Preprocess both splits the same way: add a trailing single-channel axis
# (matching the model's input_shape of (28, 28, 1)), convert to float32 and
# divide the raw pixel values by 255.
x_train, x_test = (
    split.reshape(-1, 28, 28, 1).astype('float32') / 255.0
    for split in (x_train, x_test)
)

# 3. Build the convolutional neural network.
model = Sequential([
    # Declare the input with an Input object instead of passing
    # `input_shape=` to the first layer; Keras 3 emits a warning for the
    # latter, and this form is also accepted by Keras 2.
    tf.keras.Input(shape=(28, 28, 1)),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(64, activation='relu'),
    Dropout(0.2),  # regularization to reduce overfitting
    Dense(10, activation='softmax'),  # output layer: softmax over 10 classes
])

# 4. Compile the model. The labels are passed to fit() as-is (not one-hot
# encoded), which is what sparse_categorical_crossentropy expects.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# 5. Train for five epochs, reporting metrics on the test split after each.
# NOTE(review): the test split doubles as validation data here. That does not
# influence the weights as written (no callbacks), but if early stopping or
# tuning is added, hold out part of x_train for validation instead.
history = model.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_test, y_test),
    epochs=5,
)

# 6. Evaluate the trained model on the test split and print its accuracy.
test_loss, test_acc = model.evaluate(x=x_test, y=y_test)
print(f"测试准确率: {test_acc:.4f}")

# 7. Convert the trained Keras model into serialized TensorFlow Lite bytes.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# 8. Save the TFLite model, creating the output directory (and any missing
# parents) if it does not exist yet.
tflite_models_dir = pathlib.Path("./mnist_tflite_models/")
tflite_models_dir.mkdir(parents=True, exist_ok=True)
tflite_model_file = tflite_models_dir.joinpath("mnist_model.tflite")
tflite_model_file.write_bytes(tflite_model)

# 9. Load the saved file back into a TFLite interpreter and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()

# 10. Fetch the input/output tensor metadata that predict_digit() relies on.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

# 11. Example: run inference with the TFLite model.
def predict_digit(image):
    """Classify a single digit image with the module-level TFLite interpreter.

    Args:
        image: One image whose element count equals the model input's
            (28 * 28 for this model). Accepted shapes include ``(28, 28, 1)``
            as produced by the preprocessing step, ``(28, 28)``, or a flat
            ``(784,)`` vector. Float input is assumed to already be scaled
            like the training data; integer input (e.g. raw pixel values) is
            divided by 255 to match that preprocessing.

    Returns:
        A ``(digit, confidence)`` tuple: the index of the highest-scoring
        class as an ``int`` and that class's softmax output as a ``float``.

    Raises:
        ValueError: If ``image`` does not contain exactly as many values as
            the interpreter's input tensor.
    """
    input_info = input_details[0]

    pixels = np.asarray(image)
    if np.issubdtype(pixels.dtype, np.integer):
        # Raw integer pixels: apply the same /255 scaling used on x_train/x_test.
        pixels = pixels / 255.0

    # set_tensor() rejects any dtype or shape mismatch, so cast to the dtype
    # the interpreter expects and reshape to its full input shape (this also
    # adds the leading batch axis).
    input_data = pixels.astype(input_info['dtype']).reshape(input_info['shape'])

    interpreter.set_tensor(input_info['index'], input_data)
    interpreter.invoke()

    # The output carries a leading batch axis; take the single row of scores.
    scores = interpreter.get_tensor(output_details[0]['index'])[0]
    best = int(np.argmax(scores))
    return best, float(scores[best])

# 12. Try the TFLite model on one test sample.
test_image = x_test[0]  # first sample of the test split
predicted_digit, confidence = predict_digit(test_image)

# Print the results before plt.show(), which may block until the figure
# window is closed.
print(f"预测数字: {predicted_digit}, 置信度: {confidence:.2f}")
print(f"实际数字: {y_test[0]}")

# Display the image with its prediction. The title contains CJK characters,
# which matplotlib's default font (DejaVu Sans) cannot render (they appear as
# empty boxes). Put common CJK fonts ahead of the defaults: matplotlib picks
# the first installed family and otherwise falls back to the original list.
plt.rcParams['font.sans-serif'] = [
    'SimHei',
    'Microsoft YaHei',
    'Noto Sans CJK SC',
    'WenQuanYi Zen Hei',
    *plt.rcParams['font.sans-serif'],
]
plt.imshow(test_image.reshape(28, 28), cmap='gray')
plt.title(f'预测结果: {predicted_digit}, 置信度: {confidence:.2f}')
plt.axis('off')
plt.show()
